import matplotlib.pyplot as plt


def draw_lr(lr_list, save_name):
    plt.figure()
    plt.plot(list(range(len(lr_list))), lr_list)
    plt.xlabel('iter')
    plt.ylabel('learning rate')
    plt.savefig(save_name, dpi=200)


def draw_loss(loss_list, save_name, ylabel='loss'):
    plt.figure()
    plt.plot(list(range(len(loss_list))), loss_list)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.savefig(save_name, dpi=200)


def draw_acc(acc_list, save_name):
    plt.figure()
    plt.plot(list(range(len(acc_list))), acc_list)
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.savefig(save_name, dpi=200)
